import re

from fuzzywuzzy import fuzz

class Evaluator:
    """Extract (subject, relation, object) triples and score them against a reference set.

    Triples are obtained by prompting a model through ``api_client`` and parsing
    its textual answer. Generated triples are then compared with ground-truth
    triples using ``fuzz.ratio`` and summarised as precision / recall / F1.
    """

    # Model identifier sent to ``api_client.call`` unless the caller overrides it.
    DEFAULT_MODEL = "gpt-4o-mini-2024-07-18"

    # One "(subject ~ relation ~ object)" triple. No field may span a line break,
    # so a stray "(" in prose on an earlier line cannot swallow the start of the
    # next line's triple. The object may still contain "~", as before.
    _TRIPLE_PATTERN = re.compile(r'\(([^~\n]+)~([^~\n]+)~([^)\n]+)\)')

    def __init__(self, api_client):
        """Store the client whose awaitable ``call`` method is used for extraction."""
        self.api_client = api_client

    async def extract_triples(self, document_text: str, *, model: str = DEFAULT_MODEL) -> list:
        """Ask the model for the triples contained in ``document_text``.

        Args:
            document_text: Text interpolated into the extraction prompt.
            model: Model identifier forwarded to ``api_client.call``.

        Returns:
            A list of ``(subject, relation, object)`` string tuples parsed from
            the response.
        """
        prompt = f"""
        Extract all (subject, relation, object) triples from the following document:

        Document:
        {document_text}

        Format:
        (subject ~ relation ~ object)
        """
        response = await self.api_client.call(model, prompt, for_graph=True)
        return self._parse_triples(response)

    def _parse_triples(self, text) -> list:
        """Parse every ``(a ~ b ~ c)`` occurrence in ``text`` into stripped tuples.

        An empty or ``None`` response yields an empty list instead of raising.
        """
        if not text:
            return []
        return [
            (subject.strip(), relation.strip(), obj.strip())
            for subject, relation, obj in self._TRIPLE_PATTERN.findall(text)
        ]

    @staticmethod
    def _similarity_total(triple, candidate):
        """Return the sum of ``fuzz.ratio`` over the three aligned fields."""
        return (
            fuzz.ratio(triple[0], candidate[0])
            + fuzz.ratio(triple[1], candidate[1])
            + fuzz.ratio(triple[2], candidate[2])
        )

    def match_triples(self, generated_triples, ground_truth_triples, threshold=80):
        """Pair each generated triple with its most similar ground-truth triple.

        Args:
            generated_triples: Triples to evaluate.
            ground_truth_triples: Reference triples.
            threshold: Minimum mean field similarity (same scale as
                ``fuzz.ratio``) for a pair to count as a match.

        Returns:
            ``(matches, mismatches)`` where ``matches`` holds
            ``(generated, ground_truth, score)`` tuples and ``mismatches`` holds
            the generated triples that had no candidate at or above
            ``threshold``. Several generated triples may match the same
            ground-truth triple.
        """
        matches, mismatches = [], []

        for triple in generated_triples:
            best_match, best_total = None, -1
            for candidate in ground_truth_triples:
                total = self._similarity_total(triple, candidate)
                # Strict ">" keeps the first of several equally good candidates.
                if total > best_total:
                    best_match, best_total = candidate, total

            if best_match is None:
                # Nothing to compare against: the triple is unmatched, not ignored.
                mismatches.append(triple)
                continue

            score = best_total / 3
            if score >= threshold:
                matches.append((triple, best_match, score))
            else:
                mismatches.append(triple)
        return matches, mismatches

    @staticmethod
    def _count_matched_ground_truth(matches):
        """Count distinct ground-truth triples appearing in ``matches``.

        Falls back to ``len(matches)`` when the entries are not shaped like the
        ``(generated, ground_truth, score)`` tuples built by ``match_triples``.
        """
        try:
            return len({tuple(match[1]) for match in matches})
        except (TypeError, IndexError, KeyError):
            return len(matches)

    def compute_metrics(self, generated_triples, ground_truth_triples, matches) -> dict:
        """Compute precision, recall and F1 for a matching result.

        Precision is the share of generated triples that matched. Recall is the
        share of ground-truth triples matched by at least one generated triple;
        counting distinct ground-truth triples keeps recall at or below 1.0 even
        when several generated triples hit the same reference triple.

        Returns:
            A dict with the keys ``"precision"``, ``"recall"`` and ``"f1"``;
            each value is ``0.0`` when its denominator would be zero.
        """
        matched_generated = len(matches)
        matched_truth = self._count_matched_ground_truth(matches)

        precision = matched_generated / len(generated_triples) if generated_triples else 0.0
        recall = matched_truth / len(ground_truth_triples) if ground_truth_triples else 0.0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        return {"precision": precision, "recall": recall, "f1": f1}
